## Regression models of placement outcomes
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✔ ggplot2 3.3.5     ✔ purrr   0.3.4
## ✔ tibble  3.1.5     ✔ dplyr   1.0.7
## ✔ tidyr   1.1.4     ✔ stringr 1.4.0
## ✔ readr   2.0.1     ✔ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(tidylog)
## 
## Attaching package: 'tidylog'
## The following objects are masked from 'package:dplyr':
## 
##     add_count, add_tally, anti_join, count, distinct, distinct_all,
##     distinct_at, distinct_if, filter, filter_all, filter_at, filter_if,
##     full_join, group_by, group_by_all, group_by_at, group_by_if,
##     inner_join, left_join, mutate, mutate_all, mutate_at, mutate_if,
##     relocate, rename, rename_all, rename_at, rename_if, rename_with,
##     right_join, sample_frac, sample_n, select, select_all, select_at,
##     select_if, semi_join, slice, slice_head, slice_max, slice_min,
##     slice_sample, slice_tail, summarise, summarise_all, summarise_at,
##     summarise_if, summarize, summarize_all, summarize_at, summarize_if,
##     tally, top_frac, top_n, transmute, transmute_all, transmute_at,
##     transmute_if, ungroup
## The following objects are masked from 'package:tidyr':
## 
##     drop_na, fill, gather, pivot_longer, pivot_wider, replace_na,
##     spread, uncount
## The following object is masked from 'package:stats':
## 
##     filter
library(readxl)
library(cowplot)
library(broom)
library(forcats)
library(rstanarm)
## Loading required package: Rcpp
## This is rstanarm version 2.21.1
## - See https://mc-stan.org/rstanarm/articles/priors for changes to default priors!
## - Default priors may change, so it's safest to specify priors, even if equivalent to the defaults.
## - For execution on a local, multicore CPU with excess RAM we recommend calling
##   options(mc.cores = parallel::detectCores())
## As of 2020-05-20, some kind of mismatch between `parallel` and RStudio causes a "freeze" when using multiple cores
## <https://github.com/rstudio/rstudio/issues/6692>
## NOTE(review): presumably this is why the detectCores()-based setting below is commented out in favor of a fixed value — confirm
# options(mc.cores = min(4, 
#                        parallel::detectCores() - 2))
options(mc.cores = 4)
## bayesplot makes itself the default theme, so reset the default ggplot2 theme to theme_minimal()
theme_set(theme_minimal())

library(tictoc)
library(assertthat)
## 
## Attaching package: 'assertthat'
## The following object is masked from 'package:tibble':
## 
##     has_name
## Suppress messages when generating HTML file
knitr::opts_chunk$set(message = FALSE)

## Load project helper code
## NOTE(review): posterior_estimates(), called later in this script, is presumably defined in one of these files — confirm
source('../R/predictions.R')
source('../R/posterior_estimates.R')

## Command-line interface: the `-f` / `--force` flag requests resampling of the regression model
library(argparse)
parser = ArgumentParser()
parser$add_argument('-f', '--force',
                    help = 'Force resampling the regression model',
                    default = FALSE,
                    action = 'store_true')
args = parser$parse_args()
force_resampling = args[['force']]
## To force resampling without the command-line flag, uncomment the next line
# force_resampling = TRUE

## Path prefixes for inputs and outputs
## NB output_folder ends in a filename prefix ('04_'), not a path separator
data_folder = '../data/'
output_folder = '../output/04_'
paper_folder = '../paper/'

# cluster_distances = read_csv(str_c(data_folder, 
#                                    '00_k9distances_2019-03-15.csv')) %>% 
#     count(cluster = cluster, average_distance = avgDist) %>% 
#     mutate(cluster = as.character(cluster))
# 
# ggplot(cluster_distances, aes(cluster, scale(average_distance))) +
#     geom_label(aes(label = n, fill = n, size = n), color = 'white')

## Load the parsed data
## NOTE(review): load() creates objects in the global environment implicitly; 
## `individual_df` is presumably among them — confirm
load(str_c(data_folder, '02_parsed.Rdata'))
## Programs with data validated in summer 2021
## 129
validated = read_excel(file.path(data_folder, '00_DataChecks2021APDA.xlsx')) %>% 
    filter(`Placement Page?` == 1 | `Dissertation Records/ProQuest?` == 1)

univ_df = read_rds(str_c(data_folder, '03_univ_net_stats.rds'))

## Columns that must be non-missing for a row to be kept in the analysis data
required_vars = c('permanent', 'aos_category', 
                  'graduation_year', 'prestige', 
                  'community', 'cluster_label',
                  'gender', 'frac_w', 
                  'frac_high_prestige', 'total_placements')

## Attach program-level statistics, keep only validated programs, 
## and drop rows with missing values in the required columns
individual_df = individual_df %>%
    left_join(univ_df, by = c('placing_univ_id' = 'univ_id')) %>%
    filter(placing_univ_id %in% validated$`University ID`) %>%
    ## Use the canonical names from univ_df
    select(-placing_univ) %>%
    ## Drop NAs, in the required columns only (not complete.cases() over every column)
    ## drop_na() replaces the superseded filter_at(..., all_vars(negate(is.na)(.)))
    drop_na(all_of(required_vars)) %>%
    rename(cluster = cluster_label) %>% 
    ## Express shares as percentages
    mutate(perc_w = 100*frac_w, 
           perc_high_prestige = 100*frac_high_prestige)

## Print the analysis dataset
individual_df
## Variables to consider: aos_category; graduation_year; placement_year; prestige; out_centrality; cluster; community; placing_univ_id; gender; country; perc_w; total_placements

## Overall permanent placement rate: counts of each value of `permanent`, and each count's share of all rows
individual_df %>% 
    count(permanent) %>% 
    mutate(share = n / sum(n))
## Giant pairs plot/correlogram ----
## perc_high_prestige, out_centrality, and prestige are all tightly correlated
## All other pairs have low to moderate correlation

## Correlation matrix of the candidate variables
## Non-numeric variables are coded by the integer index of their factor level; 
## the two centrality variables are put on a log10 scale
cor_matrix = individual_df %>% 
    select(permanent, aos_category, aos_diversity, perc_high_prestige,
           graduation_year, placement_year, prestige, 
           in_centrality, out_centrality, #community, 
           cluster, #average_distance,
           gender, country, perc_w, 
           total_placements) %>% 
    mutate_if(negate(is.numeric), ~ as.integer(as.factor(.))) %>% 
    mutate_at(vars(in_centrality, out_centrality), log10) %>% 
    ## Alternative: GGally::ggpairs()
    cor()

## Long format: one row per ordered pair of variables
cor_long = cor_matrix %>% 
    as_tibble(rownames = 'Var1') %>% 
    gather(key = 'Var2', value = 'cor', -Var1)

## Heatmap, with the rounded correlation printed in each tile
ggplot(cor_long, aes(x = Var1, y = Var2, fill = cor)) +
    geom_tile() +
    geom_text(aes(label = round(cor, digits = 2)), 
              color = 'white') +
    scale_fill_gradient2()

## No indication that AOS diversity has any effect
## (1*permanent puts the outcome on a numeric scale for the smoother)
individual_df %>% 
    ggplot(aes(x = aos_diversity, y = 1*permanent)) + 
    geom_point() +
    geom_smooth(method = 'loess')

## Nor for the fraction of PhDs awarded to women, either
individual_df %>% 
    ggplot(aes(x = frac_w, y = 1*permanent)) +
    geom_point() +
    geom_smooth(method = 'loess')

## Descriptive statistics ----
## Individual-level variables (all discrete)

## Number of rows at each value of each individual-level variable, in long format
desc_1_counts = individual_df %>%
    select(permanent, aos_category, 
           graduation_year, placement_year, 
           gender) %>%
    gather(key = variable, value = value) %>%
    count(variable, value) %>% 
    ## Underscores to spaces, for the facet labels
    mutate(variable = str_replace_all(variable, '_', ' '))
## Warning: attributes are not identical across measure variables;
## they will be dropped

## One horizontal bar chart per variable
desc_1_plot = ggplot(desc_1_counts, 
                     aes(x = fct_rev(value), y = n, group = variable)) +
    geom_col(aes(fill = variable), show.legend = FALSE) +
    scale_fill_brewer(palette = 'Set1') +
    xlab('') +
    coord_flip() +
    facet_wrap(vars(variable), scales = 'free', ncol = 3)
desc_1_plot

ggsave(filename = str_c(output_folder, 'descriptive_1.png'), 
       plot = desc_1_plot, 
       width = 2*3, height = 2*2, scale = 1.5)

## Program-level categorical

## Number of rows at each value of each program-level categorical variable, in long format
desc_2_counts = individual_df %>%
    select(prestige, country, 
           #community, 
           cluster) %>%
    gather(key = variable, value = value) %>%
    count(variable, value)

## One horizontal bar chart per variable
desc_2_plot = ggplot(desc_2_counts, 
                     aes(x = fct_rev(value), y = n, group = variable)) +
    geom_col(aes(fill = variable), show.legend = FALSE) +
    scale_fill_viridis_d() +
    xlab('') +
    coord_flip() +
    facet_wrap(vars(variable), scales = 'free', ncol = 3)
desc_2_plot

ggsave(filename = str_c(output_folder, 'descriptive_2.png'), 
       plot = desc_2_plot, 
       width = 2*2, height = 1*2, scale = 1.5)

## Program-level continuous variables
# individual_df %>%
#     select(frac_w, total_placements, perm_placement_rate) %>%
#     gather(key = variable, value = value) %>%
#     group_by(variable) %>%
#     summarize_at(vars(value), 
#                  funs(min, max, mean, median, sd), 
#                  na.rm = TRUE)

## Long-format values of the continuous program-level variables, under display names
## Hiring centrality is put on a log10 scale
program_cont = individual_df %>% 
    mutate(in_centrality = log10(in_centrality)) %>% 
    select(`women share` = frac_w, 
           `total placements` = total_placements, 
           `permanent placement rate` = perm_placement_rate, 
           `AOS diversity (bits)` = aos_diversity,
           `hiring centrality (log10)` = in_centrality) %>% 
    gather(key = variable, value = value)

## Per-variable summary statistics, drawn as vertical reference lines
value_means = program_cont %>% 
    group_by(variable) %>% 
    summarize(mean = mean(value))
value_medians = program_cont %>% 
    group_by(variable) %>% 
    summarize(median = median(value))

## Density + rug for each variable, with mean and median marked
desc_3_plot = ggplot(program_cont, aes(x = value)) +
    geom_density() +
    geom_rug() +
    geom_vline(data = value_means, 
               aes(xintercept = mean, color = 'mean')) +
    geom_vline(data = value_medians, 
               aes(xintercept = median, color = 'median')) +
    scale_color_brewer(palette = 'Set1', 
                       name = 'summary\nstatistic') +
    facet_wrap(~ variable, scales = 'free', ncol = 3) +
    theme(legend.position = 'bottom')
desc_3_plot

ggsave(filename = str_c(output_folder, 'descriptive_3.png'), 
       plot = desc_3_plot, 
       width = 2*3.5, height = 2*2, scale = 1.5)

## Combined descriptive figure: the three descriptive plots stacked in one column, 
## labelled automatically ('auto')
desc_plot = plot_grid(desc_1_plot, 
                      desc_2_plot, 
                      desc_3_plot, 
                      align = 'v', axis = 'lr', ncol = 1, 
                      labels = 'auto',
                      hjust = -7
                      )
desc_plot

## Pass the plot explicitly, rather than relying on ggsave()'s default of "the last plot", 
## so the saved figure cannot silently change if another plot is created in between
ggsave(str_c(output_folder, 'descriptive.png'), 
       desc_plot,
       height = 4*3, width = 3*3, scale = 1)
ggsave(str_c(paper_folder, 'fig_descriptive.png'), 
       desc_plot,
       height = 4*3, width = 3*3, scale = 1)



## Model -----
## Fit the regression of permanent placement, or reuse a cached fit
## The model is (re)fit only when no cached file exists or force_resampling is set; 
## otherwise the fitted model is read back from model_file
model_file = str_c(data_folder, '04_model.Rds')
if (!file.exists(model_file) || force_resampling) {
    ## ~700 seconds
    tic()
    model = individual_df %>% 
        ## Move 'low-prestige' and 'U.S.' to the front of their factor levels, 
        ## so that they serve as the reference categories
        mutate(prestige = fct_relevel(prestige, 'low-prestige'), 
               country = fct_relevel(country, 'U.S.')) %>% 
        ## (1|x) terms are random intercepts by group; all other terms are population-level coefficients
        stan_glmer(formula = permanent ~ 
                       (1|aos_category) +
                       gender + 
                       (1|graduation_year) +
                       (1|placement_year) +
                       1 +
                       aos_diversity +
                       # (1|community) +
                       (1|cluster) +
                       # average_distance +
                       log10(in_centrality) +
                       total_placements +
                       perc_w +
                       country +
                       prestige,
                   family = 'binomial',
                   ## Priors
                   ## Constant and coefficients
                   prior_intercept = cauchy(0, 2/3, autoscale = TRUE), ## constant term + random intercepts
                   prior = cauchy(0, 2/3, autoscale = TRUE),
                   ## error sd
                   ## NOTE(review): the prior_summary() output for this model lists no auxiliary-parameter prior; 
                   ## confirm whether prior_aux has any effect with a binomial family
                   prior_aux = cauchy(0, 2/3, autoscale = TRUE),
                   ## random effects covariance
                   prior_covariance =  decov(regularization = 1, 
                                             concentration = 1, 
                                             shape = 1, scale = 1),
                   ## Fixed seed, for reproducible sampling
                   seed = 1159518215,
                   adapt_delta = .99,
                   chains = 4, iter = 4000)
    toc()
    ## Cache the fitted model for later runs
    write_rds(model, model_file)
} else {
    model = read_rds(model_file)
}

## Report the priors used by the fitted model
prior_summary(model)
## Priors for model 'model' 
## ------
## Intercept (after predictors centered)
##  ~ cauchy(location = 0, scale = 0.67)
## 
## Coefficients
##   Specified prior:
##     ~ cauchy(location = [0,0,0,...], scale = [0.67,0.67,0.67,...])
##   Adjusted prior:
##     ~ cauchy(location = [0,0,0,...], scale = [ 1.46,10.61, 1.74,...])
## 
## Covariance
##  ~ decov(reg. = 1, conc. = 1, shape = 1, scale = 1)
## ------
## See help('prior_summary.stanreg') for more details
## Check ESS and Rhat
## Rhats all look good.  ESS a little low for grad years + some sigmas

## Per-parameter summary of the fitted model, computed once and reused below
model_diagnostics = model %>%
    summary() %>%
    as.data.frame() %>%
    rownames_to_column('parameter') %>%
    as_tibble()

## Rhat against effective sample size, with reference lines at n_eff = 3000 and Rhat = 1.01
ess_rhat_plot = model_diagnostics %>% 
    select(parameter, n_eff, Rhat) %>%
    # knitr::kable()
    ggplot(aes(n_eff, Rhat, label = parameter)) +
    geom_point() +
    geom_vline(xintercept = 3000) +
    geom_hline(yintercept = 1.01)
ess_rhat_plot

## Interactive version of the same plot, only if plotly is installed
## requireNamespace() checks availability without attaching plotly, 
## which would otherwise mask functions from packages attached earlier in this script
if (requireNamespace('plotly', quietly = TRUE)) {
    plotly::ggplotly(ess_rhat_plot)
}
## Variables w/ fewer than 3000 effective draws
## covariance on random intercepts; log posterior
model_diagnostics %>% 
    filter(n_eff < 3000) %>% 
    select(parameter, n_eff)
## Check predictions
## Posterior predictive checks, each requesting nreps = 200, shown with several plot types
## Default plot type
pp_check(model, nreps = 200)

## Bar plot
pp_check(model, nreps = 200, plotfun = 'ppc_bars')

## Rootogram
## <https://arxiv.org/pdf/1605.01311.pdf>
pp_check(model, nreps = 200, plotfun = 'ppc_rootogram')

## Rootogram, 'hanging' style
pp_check(model, nreps = 200, plotfun = 'ppc_rootogram', 
         style = 'hanging')

## 90% HPD posterior intervals
estimates = posterior_estimates(model, prob = .9)

estimates
## Estimates plot
## Point estimates and intervals for every term except the intercept, 
## the community and placement_year groups, and the 'gendero' term
estimates_plot = estimates %>% 
    filter(entity != 'intercept', 
           group != 'community',
           group != 'placement_year', 
           term != 'gendero') %>% 
    ## posterior_estimates() already exponentiates estimates, 
    ## so subtracting 1 centers "no effect" at 0 for the percent-formatted axis
    mutate_if(is.numeric, ~ . - 1) %>% 
    ggplot(aes(x = level, y = estimate, 
           ymin = conf.low, ymax = conf.high, 
           color = group)) +
    geom_hline(yintercept = 0, linetype = 'dashed') +
    geom_pointrange(size = 1.5, fatten = 1.5) + 
    scale_color_viridis_d(name = 'covariate\ngroup') +
    xlab('') + #ylab('') +
    scale_y_continuous(labels = scales::percent_format(), 
                       name = '') +
    ## Zoom to -100% to +175% without dropping data outside that range
    coord_flip(ylim = c(-1, 1.75)) +
    facet_wrap(~ entity, scales = 'free') +
    theme(legend.position = 'bottom')
estimates_plot

## Pass the plot explicitly, rather than relying on ggsave()'s default of "the last plot"
ggsave(str_c(output_folder, 'estimates.png'), 
       estimates_plot,
       width = 6, height = 4, 
       scale = 1.5)
ggsave(str_c(paper_folder, 'fig_reg_estimates.png'), 
       estimates_plot,
       width = 6, height = 4, 
       scale = 1.5)

## LaTeX table of the estimates, written to the output folder
## Same rows as the estimates plot, except that the 'gendero' term is kept
## NOTE(review): the caption describes centered intervals, while the comment where `estimates` 
## is computed says HPD — confirm which kind posterior_estimates() returns
estimates %>% 
    filter(entity != 'intercept', 
           group != 'community',
           group != 'placement_year') %>% 
    select(group, level, estimate, conf.low, conf.high) %>% 
    ## Convert factors to character before sorting
    mutate_if(is.factor, as.character) %>% 
    arrange(group, level) %>% 
    knitr::kable(format = 'latex', 
                 digits = 2,
                 booktabs = TRUE, 
                 label = 'estimates', 
                 caption = 'Estimated regression coefficients.  Lower and upper columns give the left and right endpoints, respectively, of the centered 90\\% posterior intervals.') %>% 
    ## `file` replaces the `path` argument, which is deprecated as of readr 1.4.0
    write_file(file = str_c(output_folder, 'estimates.tex'))
## Marginal effects for gender and prestige ----
## <https://stackoverflow.com/questions/45037485/calculating-marginal-effects-in-binomial-logit-using-rstanarm>
## Difference in posterior expected outcomes between two values of one variable
## 
## Two copies of `dataf` are made, one with `variable` set to `ref_value` in every row 
## and one with it set to `alt_value`; posterior_epred() is evaluated on each copy.  
## 
## dataf: data frame passed (after modification) as `newdata` to posterior_epred()
## model: fitted model passed to posterior_epred()
## variable: unquoted name of the column to overwrite
## ref_value, alt_value: reference and alternative values for `variable`
## 
## Returns the elementwise difference, alternative minus reference, of the two posterior_epred() results
marginals = function (dataf, model, variable, 
                      ref_value = 0L, 
                      alt_value = 1L) {
    ## Capture the column name so that it can be unquoted on the left-hand side of `:=`
    target = enquo(variable)
    
    newdata_ref = mutate(dataf, !!target := ref_value)
    newdata_alt = mutate(dataf, !!target := alt_value)
    
    epred_ref = posterior_epred(model, newdata = newdata_ref)
    epred_alt = posterior_epred(model, newdata = newdata_alt)
    
    epred_alt - epred_ref
}

## Summarize a matrix of marginal effects: take the mean of each row, 
## then the 5%, 50%, and 95% quantiles of those row means
## NOTE(review): rows are presumably posterior draws and columns observations — confirm against posterior_epred()
marginal_quantiles = function(effects) {
    quantile(apply(effects, 1, mean), probs = c(.05, .5, .95))
}

## posterior_linpred raises an error when there are any NAs, even in columns that aren't used by the model; 
## marginals() calls posterior_epred, and city and state are dropped before each call for this reason
## Gender: 'w' vs. 'm'
marginals_gender = marginals(select(individual_df, -city, -state), 
                             model, gender, 
                             ref_value = 'm', 
                             alt_value = 'w')

marginal_quantiles(marginals_gender)
##        5%       50%       95% 
## 0.1016608 0.1341106 0.1656972
#         5%        50%        95% 
# 0.06387576 0.10352822 0.14303032


## Prestige: 'high-prestige' vs. 'low-prestige'
marginals_prestige = marginals(select(individual_df, -city, -state), 
                               model, prestige, 
                               'low-prestige', 'high-prestige')

marginal_quantiles(marginals_prestige)
##         5%        50%        95% 
## 0.07131573 0.10708859 0.14312563
#          5%        50%        95% 
#  0.07601945 0.11718961 0.15762968

## Country: 'Canada' vs. 'U.S.'
## NB unlike the two objects above, marginals_canada stores the quantile summary, not the full matrix
marginals_canada = marginal_quantiles(
    marginals(select(individual_df, -city, -state), 
              model, country, 
              'U.S.', 'Canada')
)
marginals_canada
##          5%         50%         95% 
## -0.22438992 -0.16013292 -0.09444002
## Schools in certain communities ----
# comms_of_interest = c(3, 5, 12, 37, 54, 
#                          8, 27, 38, 43) %>% 
#     as.character()
# 
# univ_df %>% 
#     filter(community %in% comms_of_interest, 
#            total_placements > 0) %>% 
#     select(community, name = univ_name, 
#            total_placements, perm_placement_rate) %>% 
#     mutate(community = fct_relevel(community, comms_of_interest), 
#            perm_placement_rate = scales::percent_format()(perm_placement_rate)) %>% 
#     arrange(community, name) %>% 
#     knitr::kable(format = 'latex', 
#                  # digits = 2,
#                  booktabs = TRUE, 
#                  label = 'comms', 
#                  caption = 'Universities in selected topological communities.  Only universities with at least 1 placement in the data are shown.') %>% 
#     write_file(path = str_c(output_folder, 'comms.tex'))


## Reproducibility ----
## Record the R version, platform, and package versions used for this run
sessionInfo()
## R version 4.1.0 (2021-05-18)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] plotly_4.9.4.1   argparse_2.0.4   assertthat_0.2.1 tictoc_1.0.1    
##  [5] rstanarm_2.21.1  Rcpp_1.0.7       broom_0.7.9      cowplot_1.1.1   
##  [9] readxl_1.3.1     tidylog_1.0.2    forcats_0.5.1    stringr_1.4.0   
## [13] dplyr_1.0.7      purrr_0.3.4      readr_2.0.1      tidyr_1.1.4     
## [17] tibble_3.1.5     ggplot2_3.3.5    tidyverse_1.3.1 
## 
## loaded via a namespace (and not attached):
##   [1] backports_1.2.1      plyr_1.8.6           igraph_1.2.6        
##   [4] lazyeval_0.2.2       splines_4.1.0        findpython_1.0.7    
##   [7] crosstalk_1.1.1      rstantools_2.1.1     inline_0.3.19       
##  [10] digest_0.6.27        htmltools_0.5.2      rsconnect_0.8.24    
##  [13] fansi_0.5.0          magrittr_2.0.1       tzdb_0.1.2          
##  [16] modelr_0.1.8         RcppParallel_5.1.4   matrixStats_0.60.0  
##  [19] xts_0.12.1           prettyunits_1.1.1    colorspace_2.0-2    
##  [22] rvest_1.0.1          haven_2.4.1          xfun_0.24           
##  [25] callr_3.7.0          crayon_1.4.1         jsonlite_1.7.2      
##  [28] lme4_1.1-27.1        survival_3.2-11      zoo_1.8-9           
##  [31] glue_1.4.2           gtable_0.3.0         V8_3.4.2            
##  [34] pkgbuild_1.2.0       rstan_2.21.2         scales_1.1.1        
##  [37] DBI_1.1.1            miniUI_0.1.1.1       viridisLite_0.4.0   
##  [40] xtable_1.8-4         clisymbols_1.2.0     stats4_4.1.0        
##  [43] StanHeaders_2.21.0-7 DT_0.19              htmlwidgets_1.5.4   
##  [46] httr_1.4.2           threejs_0.3.3        RColorBrewer_1.1-2  
##  [49] ellipsis_0.3.2       pkgconfig_2.0.3      loo_2.4.1           
##  [52] farver_2.1.0         sass_0.4.0           dbplyr_2.1.1        
##  [55] utf8_1.2.2           tidyselect_1.1.1     labeling_0.4.2      
##  [58] rlang_0.4.11         reshape2_1.4.4       later_1.3.0         
##  [61] munsell_0.5.0        cellranger_1.1.0     tools_4.1.0         
##  [64] cli_3.0.1            generics_0.1.0       ggridges_0.5.3      
##  [67] evaluate_0.14        fastmap_1.1.0        yaml_2.2.1          
##  [70] processx_3.5.2       knitr_1.33           fs_1.5.0            
##  [73] nlme_3.1-152         mime_0.11            xml2_1.3.2          
##  [76] compiler_4.1.0       bayesplot_1.8.1      shinythemes_1.2.0   
##  [79] rstudioapi_0.13      curl_4.3.2           reprex_2.0.0        
##  [82] bslib_0.3.0          stringi_1.7.4        highr_0.9           
##  [85] ps_1.6.0             lattice_0.20-44      Matrix_1.3-4        
##  [88] nloptr_1.2.2.2       markdown_1.1         shinyjs_2.0.0       
##  [91] vctrs_0.3.8          pillar_1.6.3         lifecycle_1.0.1     
##  [94] jquerylib_0.1.4      data.table_1.14.0    httpuv_1.6.3        
##  [97] R6_2.5.1             promises_1.2.0.1     gridExtra_2.3       
## [100] codetools_0.2-18     boot_1.3-28          colourpicker_1.1.0  
## [103] MASS_7.3-54          gtools_3.9.2         withr_2.4.2         
## [106] shinystan_2.5.0      broom.mixed_0.2.7    mgcv_1.8-36         
## [109] parallel_4.1.0       hms_1.1.0            grid_4.1.0          
## [112] coda_0.19-4          minqa_1.2.4          rmarkdown_2.9       
## [115] shiny_1.6.0          lubridate_1.7.10     base64enc_0.1-3     
## [118] dygraphs_1.1.1.6